{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Lattice in MNIST\n",
    "In this tutorial, we'll show how to use lattice layer together with other layers such as neural networks.\n",
    "We will construct a neural network with 1 hidden layer for classifying hand-written digit, and then feed the output of neural network to the lattice layer to capture the possible interactions between output of neural network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import libraries\n",
    "import tensorflow as tf\n",
    "import tensorflow_lattice as tfl\n",
    "from tensorflow.examples.tutorials.mnist import input_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define helper functions\n",
    "\n",
    "# linear layer's output is output = w * input_tensor + bias.\n",
    "def _linear_layer(input_tensor, input_dim, output_dim):\n",
    "    w = tf.Variable(\n",
    "        tf.random_normal([input_dim, output_dim], mean=0.0, stddev=0.1))\n",
    "    b = tf.Variable(tf.zeros([output_dim]))\n",
    "    return tf.matmul(input_tensor, w) + b\n",
    "\n",
    "# The following function returns lattice parameters for the identity function\n",
    "# f(x1, x2, x3, ..., xn) = [x1, x2, ..., xn].\n",
    "def identity_lattice(lattice_sizes, dim=10):\n",
    "    linear_weights = []\n",
    "    for cnt in range(dim):\n",
    "        linear_weight = [0.0] * dim\n",
    "        linear_weight[cnt] = float(dim)\n",
    "        linear_weights.append(linear_weight)\n",
    "    lattice_params = tfl.python.lib.lattice_layers.lattice_param_as_linear(\n",
    "        lattice_sizes,\n",
    "        dim,\n",
    "        linear_weights=linear_weights)\n",
    "    for cnt1 in range(len(lattice_params)):\n",
    "        for cnt2 in range(len(lattice_params[cnt1])):\n",
    "            lattice_params[cnt1][cnt2] += 0.5\n",
    "      \n",
    "    return lattice_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "data_dir = '/tmp/tfl-data'\n",
    "\n",
    "# Mnist dataset contains a 28 x 28 (784) image of hand written digit and\n",
    "# a label in one-hot representation, i.e., if label == 0, it means the image\n",
    "# contains \"0\", etc. Since there are total 10 digits, the label is\n",
    "# a 10-dim vector.\n",
    "mnist = input_data.read_data_sets(data_dir, one_hot=True)\n",
    "train_batch_size = 1000\n",
    "learning_rate = 0.05\n",
    "num_steps = 3000\n",
    "\n",
    "# Placeholders for feeding the dataset.\n",
    "x = tf.placeholder(tf.float32, [None, 784])\n",
    "y_ = tf.placeholder(tf.float32, [None, 10])\n",
    "                       \n",
    "# First hidden layer has 100 hidden units.\n",
    "hidden = tf.nn.relu(_linear_layer(x, 784, 100))\n",
    "# From 100 hidden units to the final 10 dim output.\n",
    "nn_y = _linear_layer(hidden, 100, 10)\n",
    "\n",
    "# We also construct a lattice layer.\n",
    "# We apply softmax to nn_y which converts the output of neural network to the\n",
    "# probability. So nn_y is in 10 dimensional probability simplex.\n",
    "# Then 2 x 2 x ... x 2 layer uses this as an input and make a final 10 dim\n",
    "# prediction.\n",
    "output_dim = 10\n",
    "lattice_sizes = [2] * output_dim\n",
    "\n",
    "# We initialize a lattice to be the identity function.\n",
    "lattice_init = identity_lattice(lattice_sizes=lattice_sizes, dim=output_dim)\n",
    "\n",
    "# Now we define 2 x 2 x ... x 2 lattice that uses tf.nn.softmax(nn_y) as an\n",
    "# input. This is the additional non-linearity.\n",
    "lattice_output, _, _, _ = tfl.lattice_layer(\n",
    "    tf.nn.softmax(nn_y),\n",
    "    lattice_sizes=lattice_sizes,\n",
    "    output_dim=output_dim,\n",
    "    lattice_initializer=lattice_init,\n",
    "    interpolation_type='hypercube')\n",
    "\n",
    "# loss function for training NN.\n",
    "nn_cross_entropy = tf.reduce_mean(\n",
    "    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=nn_y))\n",
    "\n",
    "# loss function for training lattice + NN jointly.\n",
    "lattice_cross_entropy = tf.reduce_mean(\n",
    "    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=lattice_output))\n",
    "\n",
    "# NN training step.\n",
    "nn_train_step = tf.train.AdamOptimizer(learning_rate).minimize(nn_cross_entropy)\n",
    "\n",
    "# lattice + NN training step.\n",
    "lattice_train_step = tf.train.AdamOptimizer(0.001).minimize(\n",
    "    lattice_cross_entropy)\n",
    "\n",
    "sess = tf.Session()\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "train_ops = {'train_step': nn_train_step, 'loss': nn_cross_entropy}\n",
    "lattice_train_ops = {'train_step': lattice_train_step,\n",
    "                     'loss': lattice_cross_entropy}\n",
    "\n",
    "print('Pre training NN')\n",
    "# Pre-train NN.\n",
    "for cnt in range(num_steps):\n",
    "    batch_xs, batch_ys = mnist.train.next_batch(train_batch_size)\n",
    "    value_dict = sess.run(train_ops, feed_dict={x: batch_xs, y_: batch_ys})\n",
    "    if cnt % 1000 == 0:\n",
    "        print('loss=%f' % value_dict['loss'])\n",
    "\n",
    "\n",
    "# NN Accuracy\n",
    "correct_prediction = tf.equal(tf.argmax(lattice_output, 1),\n",
    "                              tf.argmax(y_, 1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "print('training accuracy')\n",
    "print(sess.run(accuracy, feed_dict={x: mnist.train.images,\n",
    "                                    y_: mnist.train.labels}))\n",
    "print('test accuracy')\n",
    "print(sess.run(accuracy, feed_dict={x: mnist.test.images,\n",
    "                                    y_: mnist.test.labels}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Lattice train')\n",
    "# Lattice + NN Train\n",
    "for cnt in range(num_steps):\n",
    "    batch_xs, batch_ys = mnist.train.next_batch(train_batch_size)\n",
    "    value_dict = sess.run(lattice_train_ops, feed_dict={x: batch_xs,\n",
    "                                                        y_: batch_ys})\n",
    "    if cnt % 1000 == 0:\n",
    "        print('loss=%f' % value_dict['loss']) \n",
    "\n",
    "print('training accuracy')\n",
    "print(sess.run(accuracy, feed_dict={x: mnist.train.images,\n",
    "                                    y_: mnist.train.labels}))\n",
    "print('test accuracy')\n",
    "print(sess.run(accuracy, feed_dict={x: mnist.test.images,\n",
    "                                    y_: mnist.test.labels}))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf-lattice",
   "language": "python",
   "name": "tf-lattice"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
